package tk.winshu.shortestpath.algorithm;

import tk.winshu.shortestpath.model.Node;

import java.util.*;

/**
 * Dijkstra算法
 * <p>
 * <a href="https://www.cnblogs.com/biyeymyhjob/archive/2012/07/31/2615833.html">最短路径—Dijkstra算法和Floyd算法</a>
 *
 * @author Jason Krebs
 * @date 2015年2月5日
 */
public class Dijkstra {

    private Set<Node> openNodes;
    private Set<Node> closeNodes;
    /**
     * 用于存放节点到 source 节点的距离
     */
    private Map<Node, Integer> distances;
    private Map<Node, Node> predecessors;

    /**
     * 是否有向图
     */
    private boolean isDirectedGraph;

    /**
     * @param isDirectedGraph 是否有向图
     */
    public Dijkstra(boolean isDirectedGraph) {
        this.isDirectedGraph = isDirectedGraph;
        this.openNodes = new HashSet<>();
        this.closeNodes = new HashSet<>();
        this.distances = new HashMap<>();
        this.predecessors = new HashMap<>();
    }

    public void preProcess(Node source) {
        this.openNodes.clear();
        this.closeNodes.clear();
        this.distances.clear();
        this.predecessors.clear();

        // source -> source : distance = 0
        distances.put(source, 0);
        openNodes.add(source);

        while (!openNodes.isEmpty()) {
            // 获取最短路径的节点
            Node node = pickMinimumDistanceNode();
            // 将该节点加入已处理集合中
            closeNodes.add(node);
            // 将该节点从未处理集合中移除
            openNodes.remove(node);
            // 查找最短距离并更新
            processMinimalDistances(node);
        }
    }

    private void processMinimalDistances(Node node) {
        List<Node> neighborNodes = getNeighbors(node);
        for (Node target : neighborNodes) {
            // 试图找出更短的路径
            int newDistance = getDistance(node) + (int) node.distanceTo(target);
            if (getDistance(target) > newDistance) {
                // 更新距离信息
                distances.put(target, newDistance);
                predecessors.put(target, node);
                openNodes.add(target);
            }
        }
    }

    /**
     * 在未处理节点中找出最短距离的节点
     */
    private Node pickMinimumDistanceNode() {
        Node minimum = null;
        for (Node node : openNodes) {
            if (minimum == null) {
                minimum = node;
                continue;
            }
            if (getDistance(node) < getDistance(minimum)) {
                minimum = node;
            }
        }
        return minimum;
    }

    /**
     * 获取指定节点的所有邻居节点
     */
    private List<Node> getNeighbors(Node node) {
        List<Node> neighbors = new ArrayList<>();
        List<Node> allNeighbors = isDirectedGraph ? node.getEndNodes() : node.getAllRelations();
        for (Node neighbor : allNeighbors) {
            // 排除已经在闭合节点中的邻居节点
            if (!closeNodes.contains(neighbor)) {
                neighbors.add(neighbor);
            }
        }
        return neighbors;
    }

    /**
     * 获取路径信息
     */
    public List<Node> getPath(Node target) {
        List<Node> path = new LinkedList<>();

        Node step = target;
        if (predecessors.get(step) != null) {
            path.add(step);
        }
        while (predecessors.get(step) != null) {
            step = predecessors.get(step);
            path.add(step);
        }
        // 调整顺序
        Collections.reverse(path);
        return path;
    }

    /**
     * 获取指定节点的距离，如果找不到，则认为是无穷大
     */
    private int getDistance(Node node) {
        Integer d = distances.get(node);
        return d == null ? Integer.MAX_VALUE : d;
    }

    /**
     * 查找起始节点->目标节点的路径
     * <p>
     * 该方法每次都会预处理节点信息，再查找路径
     */
    public List<Node> find(Node source, Node target) {
        preProcess(source);
        return getPath(target);
    }

}
